# Tucker DLRT: PyTorch-based Dynamical Low-Rank Training for neural networks

This repository implements the method of "Rank-adaptive spectral pruning of convolutional layers during training" (submitted to NeurIPS 2023; do not distribute).

### Demo Suite

The codebase provides DLRT and Tucker DLRT versions of several well-known neural networks, together with the benchmarks they have been run on.

| Model   | Matrix DLRT | Tucker DLRT | Working benchmarks    | Unit-tested | Model flag |
|---------|-------------|-------------|-----------------------|-------------|------------|
| LeNet5  |             | Yes         | MNIST, Fashion-MNIST  | No          | `lenet`    |
| AlexNet |             | Yes         | CIFAR-10              | No          | `alexnet`  |
| VGG16   |             | Yes         | CIFAR-10              | No          | `vgg16`    |
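
As the name suggests, Tucker DLRT represents a convolutional kernel (a 4-D tensor of shape out_channels × in_channels × kh × kw) in Tucker format: a small core tensor multiplied along each mode by a factor matrix with orthonormal columns. The NumPy sketch below illustrates that format with a truncated higher-order SVD (HOSVD), applied to a kernel with the shape of LeNet5's second convolution. The helpers (`hosvd`, `mode_product`, `reconstruct`, `tucker_param_count`) are hypothetical; this is not the repository's implementation, which updates the factors during training instead of computing them from a dense kernel.

```python
import numpy as np


def unfold(tensor, mode):
    """Mode-n unfolding: move axis `mode` to the front and flatten the rest."""
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)


def mode_product(tensor, matrix, mode):
    """Multiply `tensor` along axis `mode` by `matrix` (shape: new_dim x old_dim)."""
    moved = np.moveaxis(tensor, mode, 0)
    result = np.tensordot(matrix, moved, axes=(1, 0))
    return np.moveaxis(result, 0, mode)


def hosvd(kernel, ranks):
    """Truncated HOSVD: kernel ~= core x_1 U_1 x_2 U_2 x_3 U_3 x_4 U_4."""
    factors = []
    for mode, r in enumerate(ranks):
        u, _, _ = np.linalg.svd(unfold(kernel, mode), full_matrices=False)
        factors.append(u[:, :r])  # leading left singular vectors of each unfolding
    core = kernel
    for mode, u in enumerate(factors):
        core = mode_product(core, u.T, mode)  # project onto each factor basis
    return core, factors


def reconstruct(core, factors):
    """Expand a Tucker representation back into a dense tensor."""
    out = core
    for mode, u in enumerate(factors):
        out = mode_product(out, u, mode)
    return out


def tucker_param_count(core, factors):
    return core.size + sum(u.size for u in factors)


rng = np.random.default_rng(0)
kernel = rng.standard_normal((50, 20, 5, 5))       # out, in, kh, kw
core, factors = hosvd(kernel, ranks=(10, 8, 3, 3))
print(core.shape)                                  # (10, 8, 3, 3)
print(kernel.size, tucker_param_count(core, factors))  # 25000 1410
```

With full ranks the reconstruction is exact; lowering the ranks trades accuracy for a much smaller parameter count.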


### Installation

1. Create a Python environment (for example, with conda).
2. Install pip in the environment (run ``conda install pip``).
3. Install the project requirements (example with pip):
   ``pip install -r requirements.txt``
4. Run the example bash scripts for the test cases, contained in the folder "run_scripts".
    

### Creation of a low-rank network to optimize with DLRT

To use DLRT, the network must be wrapped into a custom `torch.nn.Module` that the DLRT optimizer can update.
This is done in three steps:

1. Create a standard `torch.nn.Module` that defines your network;
2. Wrap it with the `dlr_module` wrapper from the `wrapper` folder, choosing the parameters you need;
3. The resulting instance can now be optimized with the custom DLRT optimizer.
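
Conceptually, wrapping replaces each dense weight matrix W with factors U S Vᵀ of rank r, where U and V have orthonormal columns. The NumPy sketch below shows such a low-rank linear layer and a forward pass that applies V, S and U in turn without ever forming W. `LowRankLinear` is a hypothetical illustration, not the `dlr_module` API; the random initialization is an assumption.

```python
import numpy as np


class LowRankLinear:
    """Hypothetical low-rank linear layer: y = x W^T + b with W = U S V^T."""

    def __init__(self, in_features, out_features, rank, seed=0):
        rng = np.random.default_rng(seed)
        # Orthonormal bases from the QR factorization of Gaussian matrices
        self.U, _ = np.linalg.qr(rng.standard_normal((out_features, rank)))  # (out, r)
        self.V, _ = np.linalg.qr(rng.standard_normal((in_features, rank)))   # (in, r)
        self.S = rng.standard_normal((rank, rank)) / np.sqrt(rank)           # (r, r)
        self.bias = np.zeros(out_features)

    def dense_weight(self):
        # Only used for checking; the forward pass never needs the dense matrix
        return self.U @ self.S @ self.V.T

    def forward(self, x):
        # x V S^T U^T equals x W^T, at O(r * (in + out + r)) cost per sample
        return ((x @ self.V) @ self.S.T) @ self.U.T + self.bias

    def num_params(self):
        return self.U.size + self.S.size + self.V.size + self.bias.size


layer = LowRankLinear(in_features=800, out_features=500, rank=20)
x = np.random.default_rng(1).standard_normal((4, 800))
y = layer.forward(x)
print(y.shape)             # (4, 500)
print(layer.num_params())  # 26900, versus 800 * 500 + 500 = 400500 for the dense layer
```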

### Training

The wrapped low-rank network is then trained with the custom DLRT optimizer, as shown in the following example.
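
Dynamical low-rank training updates the factors rather than the dense weights. The NumPy sketch below shows one fixed-rank step of the K-L-S splitting used in the DLRT literature for a single matrix W = U S Vᵀ: gradient steps on K = U S and L = V Sᵀ produce new orthonormal bases, and a final S-step updates the coefficients with a gradient evaluated in the new bases. It is a simplified illustration under stated assumptions (plain gradient descent, fixed rank, a dense gradient function `grad_fn` available for clarity; `kls_step` is hypothetical), not the repository's `DLRT_Optimizer`. The two gradient evaluations per step appear to parallel the two `populate_gradients` calls (the second with `step='S'`) in the training loop of the example.

```python
import numpy as np


def kls_step(U, S, V, grad_fn, lr):
    """One fixed-rank K-L-S step (sketch) for W = U S V^T."""
    G0 = grad_fn(U @ S @ V.T)          # dense gradient at the current point
    # K-step: gradient step on K = U S (V frozen), then re-orthonormalize
    U1, _ = np.linalg.qr(U @ S - lr * G0 @ V)
    # L-step: gradient step on L = V S^T (U frozen), then re-orthonormalize
    V1, _ = np.linalg.qr(V @ S.T - lr * G0.T @ U)
    # Express the old coefficients in the new bases
    S_tilde = (U1.T @ U) @ S @ (V.T @ V1)
    # S-step: gradient step on S with the new bases fixed (needs a fresh gradient)
    G1 = grad_fn(U1 @ S_tilde @ V1.T)
    S1 = S_tilde - lr * U1.T @ G1 @ V1
    return U1, S1, V1


# Toy problem: fit a rank-2 target A with loss 0.5 * ||W - A||_F^2
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 2)) @ rng.standard_normal((2, 6))


def loss(W):
    return 0.5 * np.linalg.norm(W - A) ** 2


def grad(W):
    return W - A


U, _ = np.linalg.qr(rng.standard_normal((8, 2)))
V, _ = np.linalg.qr(rng.standard_normal((6, 2)))
S = 0.1 * rng.standard_normal((2, 2))
initial_loss = loss(U @ S @ V.T)
for _ in range(100):
    U, S, V = kls_step(U, S, V, grad, lr=0.5)
final_loss = loss(U @ S @ V.T)
print(f"loss: {initial_loss:.3f} -> {final_loss:.3e}")
```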

### Example use

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from wrapper.dlr_module_wrapper import dlr_module
# DLRTNetwork and DLRT_Optimizer are provided by this repository;
# import them from the corresponding modules before running this example.

# EXAMPLE OF A PYTORCH MODULE TO WRAP
class Lenet5(torch.nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(800, out_features=500),  # 50 channels x 4 x 4 for 28x28 inputs
            torch.nn.Tanh(),
            torch.nn.Linear(500, out_features=10),
        )

    def forward(self, x):
        for layer in self.layer:
            x = layer(x)
        return x


# Placeholder settings (example values, not tuned)
device = 'cpu'
adaptive = True
tucker = True
lr, momentum, wd = 0.05, 0.1, 1e-4

# CREATION OF THE LOW-RANK NETWORK
tau = 0.2
Lenet5_dlr = DLRTNetwork(Lenet5(device=device), adaptive=adaptive, tucker=tucker,
                         tau={'linear': tau, 'conv2d': tau})

# INITIALIZATION OF THE OPTIMIZER (it optimizes the wrapped network)
optimizer = DLRT_Optimizer(Lenet5_dlr, lr=lr, momentum=momentum, wd=wd)

criterion = torch.nn.CrossEntropyLoss()
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))])
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=trans)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

# TRAIN LOOP (one epoch)
for i, (inputs, labels) in enumerate(train_loader):
    Lenet5_dlr.zero_grad()
    optimizer.zero_grad()
    inputs, labels = inputs.to(device), labels.to(device)

    def closure():
        # recompute the gradients for the S-step
        loss = Lenet5_dlr.populate_gradients(inputs, labels, criterion, step='S')
        return loss

    loss, outputs = Lenet5_dlr.populate_gradients(inputs, labels, criterion)
    optimizer.step(closure=closure)

# EVALUATION
Lenet5_dlr.eval()
evaluation_function()   # your custom evaluation function
```
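
The `tau` argument above sets the truncation tolerance of the rank-adaptive scheme. A common relative criterion in the DLRT literature keeps the smallest rank r such that the Euclidean norm of the discarded singular values of the coefficient matrix S is at most τ‖S‖_F. The NumPy sketch below implements that criterion; whether this repository uses exactly this rule (rather than, for example, an absolute threshold) is an assumption, and `truncate_rank` is a hypothetical helper.

```python
import numpy as np


def truncate_rank(S, tau, min_rank=1):
    """Return factors (P_r, Sigma_r, Q_r) of S, keeping the smallest rank r with
    ||discarded singular values||_2 <= tau * ||S||_F (hypothetical helper)."""
    P, sigma, Qt = np.linalg.svd(S)
    threshold = tau * np.linalg.norm(sigma)   # ||S||_F equals the 2-norm of sigma
    # tail[r] = 2-norm of sigma[r:], i.e. the error made when keeping r values
    tail = np.sqrt(np.append(np.cumsum(sigma[::-1] ** 2)[::-1], 0.0))
    r = max(int(np.argmax(tail <= threshold)), min_rank)
    return P[:, :r], np.diag(sigma[:r]), Qt[:r, :].T


S = np.diag([10.0, 5.0, 1.0, 0.1])
for tau in (0.2, 0.05):
    P_r, Sigma_r, Q_r = truncate_rank(S, tau)
    print(tau, Sigma_r.shape[0])   # prints "0.2 2" then "0.05 3"
```

In a rank-adaptive step, the bases would then be rotated and shrunk as U ← U P_r and V ← V Q_r, with Sigma_r as the new coefficient matrix; a larger `tau` prunes more aggressively.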

### Loading and saving models

To load an already trained full-rank model and convert it into low-rank format, construct the full-rank model, load its trained
weights, and then wrap it, passing the argument `load_fr_weights=True`. To save and reload a low-rank model, save and load
the whole model object with `torch.save` and `torch.load`.
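
How `load_fr_weights=True` initializes the low-rank factors is not documented here. A natural choice is a truncated SVD of each trained weight matrix, which by the Eckart-Young theorem gives the best rank-r approximation in Frobenius norm. The NumPy sketch below assumes that choice; `dense_to_low_rank` is a hypothetical helper, not part of this repository.

```python
import numpy as np


def dense_to_low_rank(W, rank):
    """Hypothetical conversion of a trained dense weight W into factors
    U, S, V with W ~= U S V^T, via truncated SVD."""
    u, s, vt = np.linalg.svd(W, full_matrices=False)
    return u[:, :rank], np.diag(s[:rank]), vt[:rank, :].T


rng = np.random.default_rng(0)
W = rng.standard_normal((500, 800))           # stand-in for trained weights
U, S, V = dense_to_low_rank(W, rank=50)
err = np.linalg.norm(W - U @ S @ V.T)
s = np.linalg.svd(W, compute_uv=False)
# Eckart-Young: the error equals the norm of the discarded singular values
print(np.isclose(err, np.linalg.norm(s[50:])))  # True
```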

### Example use

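```python
import torch

# TO LOAD A FULL-RANK MODEL AND WRAP IT
f_full_rank = model()                                # `model` is your full-rank network class
f_full_rank.load_state_dict(torch.load('path.pt'))   # load_state_dict works in place
# or load a whole pickled model: f_full_rank = torch.load('model_path.pt', weights_only=False)
f = DLRTNetwork(f_full_rank, adaptive=adaptive, tucker=tucker,
                tau={'linear': tau, 'conv2d': tau},
                load_fr_weights=True, **kwargs)      # wrap the model as you prefer

# TO SAVE AND LOAD A LOW-RANK MODEL
torch.save(f, 'path_save.pt')
# weights_only=False is required on recent PyTorch versions to unpickle whole models
f = torch.load('path_save.pt', weights_only=False)
```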



